{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "source": [
    "# Sequence classification model for IMDB Sentiment Analysis\n",
    "(c) Deniz Yuret, 2019\n",
    "* Objectives: Learn the structure of the IMDB dataset and train a simple RNN model.\n",
    "* Prerequisites: [RNN models](60.rnn.ipynb)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set display width, load packages, import symbols\n",
    "ENV[\"COLUMNS\"] = 72\n",
    "using Pkg; haskey(Pkg.installed(),\"Knet\") || Pkg.add(\"Knet\")\n",
    "using Statistics: mean\n",
    "using Knet: Knet, AutoGrad, RNN, param, dropout, minibatch, nll, accuracy, progress!, adam, save, load, gc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1.0e-8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Set constants for the model and training\n",
    "EPOCHS=3          # Number of training epochs\n",
    "BATCHSIZE=64      # Number of instances in a minibatch\n",
    "EMBEDSIZE=125     # Word embedding size\n",
    "NUMHIDDEN=100     # Hidden layer size\n",
    "MAXLEN=150        # maximum size of the word sequence, pad shorter sequences, truncate longer ones\n",
    "VOCABSIZE=30000   # maximum vocabulary size, keep the most frequent 30K, map the rest to UNK token\n",
    "NUMCLASS=2        # number of output classes\n",
    "DROPOUT=0.5       # Dropout rate\n",
    "LR=0.001          # Learning rate\n",
    "BETA_1=0.9        # Adam optimization parameter\n",
    "BETA_2=0.999      # Adam optimization parameter\n",
    "EPS=1e-08         # Adam optimization parameter"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load and view data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "imdb"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "include(Knet.dir(\"data\",\"imdb.jl\"))   # defines imdb loader"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "slideshow": {
     "slide_type": "skip"
    }
   },
   "outputs": [
    {
     "data": {
      "text/latex": [
       "\\begin{verbatim}\n",
       "imdb()\n",
       "\\end{verbatim}\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "\\section{Keyword Arguments:}\n",
       "\\begin{itemize}\n",
       "\\item url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "\n",
       "\n",
       "\\item dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "\n",
       "\\item maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "\n",
       "\n",
       "\\item maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "\n",
       "\n",
       "\\item seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "\n",
       "\n",
       "\\item pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad\\_token = maxval.\n",
       "\n",
       "\n",
       "\\item stoken=true: whether to add a start token to the beginning of each sequence. start\\_token = maxval - pad.\n",
       "\n",
       "\n",
       "\\item oov=true: whether to replace words >= oov\\emph{token with oov}token (the alternative is to skip them). oov\\_token = maxval - pad - stoken.\n",
       "\n",
       "\\end{itemize}\n"
      ],
      "text/markdown": [
       "```\n",
       "imdb()\n",
       "```\n",
       "\n",
       "Load the IMDB Movie reviews sentiment classification dataset from https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict) tuple.\n",
       "\n",
       "# Keyword Arguments:\n",
       "\n",
       "  * url=https://s3.amazonaws.com/text-datasets: where to download the data (imdb.npz) from.\n",
       "  * dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "  * maxval=nothing: max number of token values to include. Words are ranked by how often they occur (in the training set) and only the most frequent words are kept. nothing means keep all, equivalent to maxval = vocabSize + pad + stoken.\n",
       "  * maxlen=nothing: truncate sequences after this length. nothing means do not truncate.\n",
       "  * seed=0: random seed for sample shuffling. Use system seed if 0.\n",
       "  * pad=true: whether to pad short sequences (padding is done at the beginning of sequences). pad_token = maxval.\n",
       "  * stoken=true: whether to add a start token to the beginning of each sequence. start_token = maxval - pad.\n",
       "  * oov=true: whether to replace words >= oov*token with oov*token (the alternative is to skip them). oov_token = maxval - pad - stoken.\n"
      ],
      "text/plain": [
       "\u001b[36m  imdb()\u001b[39m\n",
       "\n",
       "  Load the IMDB Movie reviews sentiment classification dataset from\n",
       "  https://keras.io/datasets and return (xtrn,ytrn,xtst,ytst,dict)\n",
       "  tuple.\n",
       "\n",
       "\u001b[1m  Keyword Arguments:\u001b[22m\n",
       "\u001b[1m  ≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡≡\u001b[22m\n",
       "\n",
       "    •    url=https://s3.amazonaws.com/text-datasets: where to\n",
       "        download the data (imdb.npz) from.\n",
       "\n",
       "    •    dir=Pkg.dir(\"Knet/data\"): where to cache the data.\n",
       "\n",
       "    •    maxval=nothing: max number of token values to include.\n",
       "        Words are ranked by how often they occur (in the training\n",
       "        set) and only the most frequent words are kept. nothing\n",
       "        means keep all, equivalent to maxval = vocabSize + pad +\n",
       "        stoken.\n",
       "\n",
       "    •    maxlen=nothing: truncate sequences after this length.\n",
       "        nothing means do not truncate.\n",
       "\n",
       "    •    seed=0: random seed for sample shuffling. Use system seed\n",
       "        if 0.\n",
       "\n",
       "    •    pad=true: whether to pad short sequences (padding is done\n",
       "        at the beginning of sequences). pad_token = maxval.\n",
       "\n",
       "    •    stoken=true: whether to add a start token to the beginning\n",
       "        of each sequence. start_token = maxval - pad.\n",
       "\n",
       "    •    oov=true: whether to replace words >= oov\u001b[4mtoken with\n",
       "        oov\u001b[24mtoken (the alternative is to skip them). oov_token =\n",
       "        maxval - pad - stoken."
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "@doc imdb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "┌ Info: Loading IMDB...\n",
      "└ @ Main /home/jrun/.julia/packages/Knet/T1oum/data/imdb.jl:57\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      " 12.955948 seconds (29.36 M allocations: 1.497 GiB, 7.69% gc time)\n"
     ]
    }
   ],
   "source": [
    "@time (xtrn,ytrn,xtst,ytst,imdbdict)=imdb(maxlen=MAXLEN,maxval=VOCABSIZE);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "25000-element Array{Array{Int32,1},1}\n",
      "25000-element Array{Int8,1}\n",
      "Dict{String,Int32} with 88584 entries\n"
     ]
    }
   ],
   "source": [
    "println.(summary.((xtrn,ytrn,xtst,ytst,imdbdict)));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×150 LinearAlgebra.Adjoint{Int32,Array{Int32,1}}:\n",
       " 2  734  28  9  761  37  32  2686  …  2165  163  2018  2828  165  49"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Words are encoded with integers\n",
    "rand(xtrn)'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int64,Array{Int64,1}}:\n",
       " 150  150  150  150  150  150  150  …  150  150  150  150  150  150"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Each word sequence is padded or truncated to length 150\n",
    "length.(xtrn)'"
   ]
  },
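  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "According to the `imdb` docstring above, padding is done at the beginning of each sequence and the highest ids are reserved for special tokens. The following is a minimal sketch of that layout, assuming `maxval=30000` as in this notebook. `leftpad` is a hypothetical helper written for illustration only; it is not part of Knet or `imdb.jl`, and truncation is not shown.\n",
    "\n",
    "```julia\n",
    "# Hypothetical illustration, not the imdb.jl implementation\n",
    "maxval = 30000                                   # same value as VOCABSIZE\n",
    "padid, startid, unkid = maxval, maxval-1, maxval-2\n",
    "\n",
    "# Prepend pad ids until the sequence has maxlen elements\n",
    "leftpad(seq, maxlen) = vcat(fill(padid, max(0, maxlen - length(seq))), seq)\n",
    "\n",
    "leftpad([startid, 12, 7], 6)                     # three pad ids, then the original three ids\n",
    "```\n",
    "\n",
    "Because the pads come first, the last position of every sequence holds a real word, which matters for a model that reads its prediction from the final time step."
   ]
  },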
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "slideshow": {
     "slide_type": "slide"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "reviewstring (generic function with 2 methods)"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Define a function that can print the actual words:\n",
    "imdbvocab = Array{String}(undef,length(imdbdict))\n",
    "for (k,v) in imdbdict; imdbvocab[v]=k; end\n",
    "imdbvocab[VOCABSIZE-2:VOCABSIZE] = [\"<unk>\",\"<s>\",\"<pad>\"]\n",
    "function reviewstring(x,y=0)\n",
    "    x = x[x.!=VOCABSIZE] # remove pads\n",
    "    \"\"\"$((\"Sample\",\"Negative\",\"Positive\")[y+1]) review:\\n$(join(imdbvocab[x],\" \"))\"\"\"\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Negative review:\n",
      "<s> i saw this movie for 2 reasons i like gerard butler and christopher plummer unfortunately these poor men were forced to carry a pretty dumb movie i liked the idea that dracula is actually a reincarnation of judas <unk> because it does explain his disdain for all things christian but there was so much camp that this idea was not realized as much as it could have been i see this movie more as a way for the talented gerard butler to pay his dues before being truly recognized and a way for the legendary christopher plummer to remind the public me and the 5 other people who saw this film that he still exists i actually enjoyed the special features on the dvd more than the movie itself\n"
     ]
    }
   ],
   "source": [
    "# Hit Ctrl-Enter to see random reviews:\n",
    "r = rand(1:length(xtrn))\n",
    "println(reviewstring(xtrn[r],ytrn[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1×25000 LinearAlgebra.Adjoint{Int8,Array{Int8,1}}:\n",
       " 2  1  1  2  2  2  1  2  2  2  1  …  2  1  2  1  1  1  1  2  2  2  2"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Here are the labels: 1=negative, 2=positive\n",
    "ytrn'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "slideshow": {
     "slide_type": "fragment"
    }
   },
   "source": [
    "## Define the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "struct SequenceClassifier; input; rnn; output; pdrop; end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "SequenceClassifier"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "SequenceClassifier(input::Int, embed::Int, hidden::Int, output::Int; pdrop=0) =\n",
    "    SequenceClassifier(param(embed,input), RNN(embed,hidden,rnnType=:gru), param(output,hidden), pdrop)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "function (sc::SequenceClassifier)(input)\n",
    "    embed = sc.input[:, permutedims(hcat(input...))]\n",
    "    embed = dropout(embed,sc.pdrop)\n",
    "    hidden = sc.rnn(embed)\n",
    "    hidden = dropout(hidden,sc.pdrop)\n",
    "    return sc.output * hidden[:,:,end]\n",
    "end\n",
    "\n",
    "(sc::SequenceClassifier)(input,output) = nll(sc(input),output)"
   ]
  },
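  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The embedding lookup in the model indexes the columns of a matrix with a matrix of word ids. A small sketch in plain Julia (no Knet required) may make the resulting shapes easier to see. The toy sizes and the names `W`, `batch`, `ids` and `E` are assumptions made for illustration.\n",
    "\n",
    "```julia\n",
    "W = randn(4, 10)                     # toy embedding matrix: 4 dimensions, 10 words\n",
    "batch = [[1,2,3], [4,5,6]]           # B=2 sequences of length T=3\n",
    "ids = permutedims(hcat(batch...))    # 2×3 matrix of word ids (B×T)\n",
    "E = W[:, ids]                        # size(E) == (4, 2, 3), i.e. (embed, B, T)\n",
    "E[:, 2, 3] == W[:, 6]                # true: embedding of the 3rd word of the 2nd sequence\n",
    "```\n",
    "\n",
    "Indexing with `end` in the last dimension, as in `hidden[:,:,end]`, then selects the final time step for every sequence in the batch."
   ]
  },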
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Experiment"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(390, 390)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dtrn = minibatch(xtrn,ytrn,BATCHSIZE;shuffle=true)\n",
    "dtst = minibatch(xtst,ytst,BATCHSIZE)\n",
    "length.((dtrn,dtst))"
   ]
  },
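  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The batch count above can be checked with simple arithmetic. This sketch assumes, as the count of 390 suggests, that `minibatch` drops the final incomplete batch; the variable names are chosen for illustration.\n",
    "\n",
    "```julia\n",
    "ninstances, batchsize, epochs = 25000, 64, 3\n",
    "nbatches = div(ninstances, batchsize)   # 390 full minibatches\n",
    "leftover = ninstances % batchsize       # 40 instances do not fill a batch\n",
    "updates  = epochs * nbatches            # 1170 parameter updates over 3 epochs\n",
    "```\n",
    "\n",
    "The 1170 updates agree with the iteration count recorded in the progress bar comment of the training cell."
   ]
  },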
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "trainresults (generic function with 1 method)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# For running experiments\n",
    "function trainresults(file,maker; o...)\n",
    "    if (print(\"Train from scratch? \"); readline()[1]=='y')\n",
    "        model = maker()\n",
    "        progress!(adam(model,repeat(dtrn,EPOCHS);lr=LR,beta1=BETA_1,beta2=BETA_2,eps=EPS))\n",
    "        Knet.save(file,\"model\",model)\n",
    "        Knet.gc() # To save gpu memory\n",
    "    else\n",
    "        isfile(file) || download(\"http://people.csail.mit.edu/deniz/models/tutorial/$file\",file)\n",
    "        model = Knet.load(file,\"model\")\n",
    "    end\n",
    "    return model\n",
    "end"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "maker (generic function with 1 method)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "maker() = SequenceClassifier(VOCABSIZE,EMBEDSIZE,NUMHIDDEN,NUMCLASS,pdrop=DROPOUT)\n",
    "# model = maker()\n",
    "# nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.69312066f0, 0.69312423f0, 0.5135817307692307, 0.5096153846153846)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train from scratch? stdin> n\n"
     ]
    }
   ],
   "source": [
    "# 2.51e-01  100.00%┣████████████████████┫ 1170/1170 [00:16/00:16, 75.46i/s]\n",
    "model = trainresults(\"imdbmodel113.jld2\",maker);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "#nll(model,dtrn), nll(model,dtst), accuracy(model,dtrn), accuracy(model,dtst)\n",
    "# (0.059155148f0, 0.3877507f0, 0.9846153846153847, 0.8583733974358975)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Playground"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "str2ids (generic function with 1 method)"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictstring(x)=\"\\nPrediction: \" * (\"Negative\",\"Positive\")[argmax(Array(vec(model([x]))))]\n",
    "UNK = VOCABSIZE-2\n",
    "str2ids(s::String)=[(i=get(imdbdict,w,UNK); i>=UNK ? UNK : i) for w in split(lowercase(s))]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Positive review:\n",
      "black <unk> and have never forgotten it i don't think i ever want to see it again it was just too much br br spoiler ahead the scene which will not leave me is when <unk> meets a prostitute who has to <unk> her own fetus you don't see her do it but you get a quick glance at what she got out it's almost 20 years later and just recalling that scene upsets me spoiler end br br the movie gets more brutal as it goes along and ends the only way it can what's all the more harrowing is stories like this really did happen in brazil in 1981 and are still happening today br br a harrowing brutal film but it should be seen if you can handle it i'm surprised this got an r rating i've seen x rated film that are less graphic a 10\n",
      "\n",
      "Prediction: Positive\n"
     ]
    }
   ],
   "source": [
    "# Here we can see predictions for random reviews from the test set; hit Ctrl-Enter to sample:\n",
    "r = rand(1:length(xtst))\n",
    "println(reviewstring(xtst[r],ytst[r]))\n",
    "println(predictstring(xtst[r]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "stdin> awesome\n",
      "\n",
      "Prediction: Positive\n"
     ]
    }
   ],
   "source": [
    "# Here the user can enter their own reviews and classify them:\n",
    "println(predictstring(str2ids(readline(stdin))))"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "julia.ipynb",
   "provenance": [],
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Julia 1.0.3",
   "language": "julia",
   "name": "julia-1.0"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "1.0.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
